package com.github.hgkmail.hello.leetcode101.pointer.tree;

import com.github.hgkmail.hello.leetcode101.base.CommonUtil;
import com.github.hgkmail.hello.leetcode101.base.TreeNode;

import java.util.HashMap;
import java.util.Map;

/**
 * LeetCode 437 "Path Sum III": counts downward paths (each step goes from a parent to one of its
 * children) whose node values sum to a target. A path may start and end at any node; it does not
 * have to begin at the root or end at a leaf.
 *
 * <p>Two solutions are provided:
 * <ul>
 *   <li>{@link #pathSum(TreeNode, int)}: double recursion. Every node is tried as a path start,
 *       so the work is the sum of all subtree sizes (O(n * height)).</li>
 *   <li>{@link #pathSum2(TreeNode, int)}: prefix sums with backtracking. Every node is visited
 *       once with O(1) expected hash-map work, so O(n) expected time.</li>
 * </ul>
 */
public class LC437PathSum3 {
    /**
     * Counts downward paths that start at {@code root} and sum to {@code targetSum}.
     *
     * <p>Pre-order traversal: subtract the current node's value from the target and keep
     * searching in both subtrees. A path may end at any node, so a match here does not stop the
     * search. Values further down (possibly negative) can bring the sum back to the target.
     *
     * @param root      first node of every counted path; {@code null} yields 0
     * @param targetSum remaining sum still needed; a {@code long} so repeatedly subtracting
     *                  {@code int} node values does not overflow
     * @return number of matching paths that start at {@code root}
     */
    public int pathSumFrom(TreeNode root, long targetSum) {
        if (root == null) {
            return 0;
        }
        int count = 0;
        if (root.val == targetSum) {
            count += 1;
        }
        long remaining = targetSum - root.val;
        count += pathSumFrom(root.left, remaining);
        count += pathSumFrom(root.right, remaining);
        return count;
    }

    /**
     * Counts all downward paths in the tree that sum to {@code targetSum} (case analysis + DFS).
     *
     * <p>Every path either starts at {@code root} (counted by {@link #pathSumFrom}) or lies
     * entirely within the left or right subtree (counted recursively).
     *
     * @param root      tree root; {@code null} yields 0
     * @param targetSum required path sum
     * @return number of matching downward paths
     */
    public int pathSum(TreeNode root, int targetSum) {
        if (root == null) {
            return 0;
        }

        return pathSumFrom(root, targetSum)
                + pathSum(root.left, targetSum)
                + pathSum(root.right, targetSum);
    }

    /**
     * Pre-order DFS helper for {@link #pathSum2}. It counts paths ending at each node by using
     * prefix sums along the current root-to-node path.
     *
     * <p>Backtracking is required: a node's prefix sum is added to {@code prefixSumCounts} only
     * while its own subtree is being explored, and removed afterwards. This way the map only ever
     * describes ancestors of the node being visited, which is what makes the sums true "prefixes".
     *
     * @param root            current node; {@code null} yields 0
     * @param targetSum       required path sum
     * @param prefixSumCounts prefix sum -> how many prefixes on the current path have that sum;
     *                        restored to its entry contents (minus any zero-count keys) on return
     * @param currPrefixSum   sum of the values from the tree root down to {@code root}'s parent
     * @return number of matching paths that end somewhere in {@code root}'s subtree
     */
    public int dfs(TreeNode root, long targetSum, Map<Long, Integer> prefixSumCounts, long currPrefixSum) {
        if (root == null) {
            return 0;
        }
        currPrefixSum += root.val;

        // A path ending here sums to targetSum iff some earlier prefix equals
        // currPrefixSum - targetSum. Look up BEFORE inserting this node's own prefix so an
        // empty path is never counted (matters when targetSum == 0).
        int count = prefixSumCounts.getOrDefault(currPrefixSum - targetSum, 0);

        prefixSumCounts.merge(currPrefixSum, 1, Integer::sum);
        count += dfs(root.left, targetSum, prefixSumCounts, currPrefixSum);
        count += dfs(root.right, targetSum, prefixSumCounts, currPrefixSum);
        // Backtrack. Drop the key when its count reaches zero so the map does not accumulate
        // stale entries for every prefix sum seen in other branches.
        prefixSumCounts.computeIfPresent(currPrefixSum, (prefix, cnt) -> cnt == 1 ? null : cnt - 1);

        return count;
    }

    /**
     * Counts all downward paths in the tree that sum to {@code targetSum} using prefix sums.
     *
     * <p>For an ancestor A and a descendant D, the sum of the path just below A down to D equals
     * prefix(D) - prefix(A). The DFS therefore counts, at each node, how many ancestor prefixes
     * equal prefix(node) - targetSum.
     *
     * @param root      tree root; {@code null} yields 0
     * @param targetSum required path sum
     * @return number of matching downward paths
     */
    public int pathSum2(TreeNode root, int targetSum) {
        // prefix sum -> number of prefixes on the current root-to-node path with that sum
        Map<Long, Integer> prefixSumCounts = new HashMap<>();
        // The empty prefix (sum 0) lets paths that start at the tree root be counted.
        prefixSumCounts.put(0L, 1);
        return dfs(root, targetSum, prefixSumCounts, 0);
    }

    /**
     * Demo: runs {@link #pathSum2} on the LeetCode 437 example tree with target 8.
     * The expected output is 3 (5->3, 5->2->1, -3->11).
     */
    public static void main(String[] args) {
        // Pre-order serialization with '#' for null children (presumably the format
        // CommonUtil.deserializeBinaryTree parses — confirm against its implementation):
        //          10
        //         /  \
        //        5    -3
        //       / \     \
        //      3   2     11
        //     / \   \
        //    3  -2   1
        TreeNode root = CommonUtil.deserializeBinaryTree("10,5,3,3,#,#,-2,#,#,2,#,1,#,#,-3,#,11,#,#");

        System.out.println(new LC437PathSum3().pathSum2(root, 8));
    }
}
